660e39
@@ -111,7 +111,7 @@
   private final AtomicBoolean isShutdown = new AtomicBoolean(false);
   // Tracks appMasters to which heartbeats are being sent. This should not be used for any other
   // messages like taskKilled, etc.
-  private final Map<QueryIdentifier, AMNodeInfo> knownAppMasters = new HashMap<>();
+  private final Map<QueryIdentifier, Map<LlapNodeId, AMNodeInfo>> knownAppMasters = new HashMap<>();
   volatile ListenableFuture<Void> queueLookupFuture;
   private final DaemonId daemonId;
 
@@ -208,11 +208,16 @@
public void registerTask(String amLocation, int port, String umbilicalUser,
     // and discard AMNodeInfo instances per query.
     synchronized (knownAppMasters) {
       LlapNodeId amNodeId = LlapNodeId.getInstance(amLocation, port);
-      amNodeInfo = knownAppMasters.get(queryIdentifier);
+      Map<LlapNodeId, AMNodeInfo> amNodeInfoPerQuery = knownAppMasters.get(queryIdentifier);
+      if (amNodeInfoPerQuery == null) {
+        amNodeInfoPerQuery = new HashMap<>();
+        knownAppMasters.put(queryIdentifier, amNodeInfoPerQuery);
+      }
+      amNodeInfo = amNodeInfoPerQuery.get(amNodeId);
       if (amNodeInfo == null) {
-        amNodeInfo = new AMNodeInfo(amNodeId, umbilicalUser, jobToken, queryIdentifier,
-            retryPolicy, retryTimeout, socketFactory, conf);
-        knownAppMasters.put(queryIdentifier, amNodeInfo);
+        amNodeInfo = new AMNodeInfo(amNodeId, umbilicalUser, jobToken, queryIdentifier, retryPolicy,
+          retryTimeout, socketFactory, conf);
+        amNodeInfoPerQuery.put(amNodeId, amNodeInfo);
         // Add to the queue only the first time this is registered, and on
         // subsequent instances when it's taken off the queue.
         amNodeInfo.setNextHeartbeatTime(System.currentTimeMillis() + heartbeatInterval);
@@ -233,7 +238,7 @@
public void unregisterTask(String amLocation, int port, QueryIdentifier queryIde
     }
     AMNodeInfo amNodeInfo;
     synchronized (knownAppMasters) {
-      amNodeInfo = knownAppMasters.get(queryIdentifier);
+      amNodeInfo = getAMNodeInfo(amLocation, port, queryIdentifier);
       if (amNodeInfo == null) {
         LOG.info(("Ignoring duplicate unregisterRequest for am at: " + amLocation + ":" + port));
       } else {
@@ -249,7 +254,7 @@
public void taskKilled(String amLocation, int port, String umbilicalUser, Token<
     LlapNodeId amNodeId = LlapNodeId.getInstance(amLocation, port);
     AMNodeInfo amNodeInfo;
     synchronized (knownAppMasters) {
-      amNodeInfo = knownAppMasters.get(queryIdentifier);
+      amNodeInfo = getAMNodeInfo(amLocation, port, queryIdentifier);
       if (amNodeInfo == null) {
         amNodeInfo = new AMNodeInfo(amNodeId, umbilicalUser, jobToken, queryIdentifier, retryPolicy, retryTimeout, socketFactory,
           conf);
@@ -277,20 +282,22 @@
public void onFailure(Throwable t) {
   public void queryComplete(QueryIdentifier queryIdentifier) {
     if (queryIdentifier != null) {
       synchronized (knownAppMasters) {
-        AMNodeInfo amNodeInfo = knownAppMasters.remove(queryIdentifier);
+        LOG.debug("Query complete received for {}", queryIdentifier);
+        Map<LlapNodeId, AMNodeInfo> amNodeInfoPerQuery = knownAppMasters.remove(queryIdentifier);
 
         // The AM can be used for multiple queries. This is an indication that a single query is complete.
         // We don't have a good mechanism to know when an app ends. Removing this right now ensures
         // that a new one gets created for the next query on the same AM.
-        if (amNodeInfo != null) {
-          amNodeInfo.setIsDone(true);
+        if (amNodeInfoPerQuery != null) {
+          LOG.debug("Removed following AMs due to query complete:");
+          for (AMNodeInfo amNodeInfo : amNodeInfoPerQuery.values()) {
+            amNodeInfo.setIsDone(true);
+            LOG.debug("{}", amNodeInfo); // parameterized: toString() is skipped when debug is off
+          }
         }
         // TODO: not stopping umbilical explicitly as some taskKill requests may get scheduled during queryComplete
         // which will be using the umbilical. HIVE-16021 should fix this, until then leave umbilical open and wait for
         // it to be closed after max idle timeout (10s default)
-        if (LOG.isDebugEnabled()) {
-          LOG.debug("Query complete received. Removed {}.", amNodeInfo);
-        }
       }
     }
   }
@@ -419,9 +426,34 @@
protected Void callInternal() {
     }
   }
 
+  // Opens the umbilical RPC proxy to an AM, running as its umbilical user with the job token added.
+  protected LlapTaskUmbilicalProtocol createUmbilical(final AMNodeInfo amNodeInfo)
+    throws IOException, InterruptedException {
+    final InetSocketAddress amAddress = NetUtils.createSocketAddrForHost(
+      amNodeInfo.amNodeId.getHostname(), amNodeInfo.amNodeId.getPort());
+    SecurityUtil.setTokenService(amNodeInfo.jobToken, amAddress);
+    UserGroupInformation remoteUgi = UserGroupInformation.createRemoteUser(amNodeInfo.umbilicalUser);
+    remoteUgi.addToken(amNodeInfo.jobToken);
+    return remoteUgi.doAs(new PrivilegedExceptionAction<LlapTaskUmbilicalProtocol>() {
+      @Override
+      public LlapTaskUmbilicalProtocol run() throws Exception {
+        return RPC.getProxy(LlapTaskUmbilicalProtocol.class,
+          LlapTaskUmbilicalProtocol.versionID, amAddress, UserGroupInformation.getCurrentUser(),
+          amNodeInfo.conf, amNodeInfo.socketFactory, (int) (amNodeInfo.timeout));
+      }
+    });
+  }
 
+  // Finds the AM entry registered for (queryId, amHost:amPort); null if absent. Takes no lock itself.
+  private AMNodeInfo getAMNodeInfo(String amHost, int amPort, QueryIdentifier queryId) {
+    Map<LlapNodeId, AMNodeInfo> perQuery = knownAppMasters.get(queryId);
+    if (perQuery == null) {
+      return null;
+    }
+    return perQuery.get(LlapNodeId.getInstance(amHost, amPort));
+  }
 
-  private static class AMNodeInfo implements Delayed {
+  protected class AMNodeInfo implements Delayed {
     // Serves as lock for itself.
     private final Set<TezTaskAttemptID> tasks = new HashSet<>();
     private final String umbilicalUser;
@@ -457,20 +489,7 @@
public AMNodeInfo(LlapNodeId amNodeId, String umbilicalUser,
 
     synchronized LlapTaskUmbilicalProtocol getUmbilical() throws IOException, InterruptedException {
       if (umbilical == null) {
-        final InetSocketAddress address =
-            NetUtils.createSocketAddrForHost(amNodeId.getHostname(), amNodeId.getPort());
-        SecurityUtil.setTokenService(this.jobToken, address);
-        UserGroupInformation ugi = UserGroupInformation.createRemoteUser(umbilicalUser);
-        ugi.addToken(jobToken);
-        umbilical = ugi.doAs(new PrivilegedExceptionAction<LlapTaskUmbilicalProtocol>() {
-          @Override
-          public LlapTaskUmbilicalProtocol run() throws Exception {
-            return RPC
-                .getProxy(LlapTaskUmbilicalProtocol.class, LlapTaskUmbilicalProtocol.versionID,
-                    address, UserGroupInformation.getCurrentUser(), conf, socketFactory,
-                    (int) timeout);
-          }
-        });
+        umbilical = createUmbilical(this); // built on first use via the outer class's protected factory
       }
       return umbilical;
     }
